import os
import csv
from jieba import analyse


def stripdata(data_pth, out_pth):
    """Append keywords extracted from movie text to each movie's tag list.

    Reads two CSV files from ``data_pth``:

    * ``Movie_tag.csv`` (GBK-encoded): column 0 is the movie id, column 1 a
      comma-separated tag string. The first row is treated as a header.
    * ``movie.csv`` (UTF-8): column 0 is the movie id and column 2 is the
      text handed to jieba's keyword extractor (the plot synopsis, per the
      original comments). The first row is treated as a header.

    For every movie row whose id appears in the tag table, the top 5
    keywords returned by ``analyse.extract_tags`` are appended to that id's
    tag string. The resulting table is written to
    ``Movie_tag_add.csv`` (UTF-8, header ``id,tag``) inside ``out_pth``,
    with rows in the same order as ``Movie_tag.csv``.

    Movies whose id has no row in the tag table are skipped, and blank CSV
    lines are ignored. If an id occurs more than once in the tag table, only
    its first row receives the extracted keywords.

    Args:
        data_pth: Directory containing ``movie.csv`` and ``Movie_tag.csv``.
        out_pth: Existing directory that receives ``Movie_tag_add.csv``.

    Raises:
        OSError: If an input file cannot be opened or the output file
            cannot be created.
        IndexError: If a non-blank row has fewer columns than required.
    """
    movie_file = os.path.join(data_pth, "movie.csv")
    tag_file = os.path.join(data_pth, "Movie_tag.csv")
    output_file = os.path.join(out_pth, "Movie_tag_add.csv")

    # All three files are opened up front (so a missing path fails before any
    # work is done) and are closed on every exit path, including errors.
    # newline="" lets the csv module handle line endings itself.
    with open(movie_file, mode="r", encoding="utf-8", newline="") as fp_movie, \
            open(tag_file, mode="r", encoding="gbk", newline="") as fp_keywords, \
            open(output_file, mode="w", encoding="utf-8", newline="") as fp_output:
        writer = csv.writer(fp_output)
        writer.writerow(["id", "tag"])
        print("strip")

        # Load the tag table.
        keywords_csv = csv.reader(fp_keywords)
        headers_keywords = next(keywords_csv, None)
        print(headers_keywords)

        tag_rows = []    # [movie_id, tag_string] pairs, in input order
        row_by_id = {}   # movie_id -> its first entry in tag_rows (O(1) lookup)
        for row in keywords_csv:
            if not row:
                continue  # csv.reader yields [] for blank lines
            entry = [row[0], row[1]]
            tag_rows.append(entry)
            # setdefault keeps the first occurrence, like list.index() did.
            row_by_id.setdefault(row[0], entry)

        # Extract keywords from each movie's text and merge them in.
        text_csv = csv.reader(fp_movie)
        next(text_csv, None)  # skip the header row
        for row in text_csv:
            if not row:
                continue
            entry = row_by_id.get(row[0])
            if entry is None:
                # No tag row for this movie: nothing to extend.
                continue
            extracted = analyse.extract_tags(
                row[2], topK=5, withWeight=False, allowPOS=()
            )
            # "".split(",") would yield [""] and produce a leading comma.
            tags = entry[1].split(",") if entry[1] else []
            tags.extend(extracted)
            entry[1] = ",".join(tags)

        # Write the merged table back out.
        writer.writerows(tag_rows)
            
                

if __name__ == "__main__":
    # Resolve the data/output directories relative to this script's real
    # location so the script works regardless of the current directory.
    pth = os.path.dirname(os.path.realpath(__file__))
    data_pth = os.path.join(pth, "data")
    out_pth = os.path.join(pth, "output")

    # The output directory may not exist on a fresh checkout; opening the
    # result file for writing would otherwise fail with FileNotFoundError.
    os.makedirs(out_pth, exist_ok=True)

    stripdata(data_pth, out_pth)